%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%
% function [Ps,Pm] = do_frank_wolfe(Ps,A,B,num_updates)
%
% Performs num_updates Frank-Wolfe updates on the doubly stochastic
% matrix Ps to obtain a higher score, printing one progress row per update.
% Pm is returned as the projection of Ps onto the space of
% permutation matrices (via permutation_match).
%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [Ps,Pm] = do_frank_wolfe(Ps,A,B,num_updates)
% Run num_updates Frank-Wolfe steps on the doubly stochastic matrix Ps and
% print one progress row per step.
%
% Inputs:
%   Ps          - starting doubly stochastic matrix
%   A, B        - matrices passed to compute_gradient and used in the
%                 vertex score sum(min(A*Pm,Pm*B),'all')
%   num_updates - number of Frank-Wolfe iterations to perform
% Outputs:
%   Ps - matrix after the updates
%   Pm - permutation_match(Ps) for the final Ps
%
% Printed columns: iteration, score of Pm ("vertex"), rounded
% 0.5*sum(Gs.*Ps,'all') where Gs is the gradient at the SAME Ps
% ("simplex", presumably the relaxed objective -- TODO confirm against
% compute_gradient), elapsed minutes, and the seconds frank_wolfe_update
% reports for its permutation_match call.

tStart = tic;
secs_per_minute = 60;
row_fmt = '     %02d   %07d   %07d   %04.1f   %04.1f\n';
vertex_score = @(P) full(sum(min(A*P,P*B),'all'));

fprintf(1,'%%\n%% Frank-Wolfe updates:\n');
fprintf(1,'%%  iter    vertex   simplex   tMin   tMatch\n');

Pm = permutation_match(Ps);
if (isequal(Pm,Ps))
  % Ps is already a permutation, so both score columns show the same value.
  scorePm = vertex_score(Pm);
  tMin = toc(tStart)/secs_per_minute;
  fprintf(1,row_fmt,0,scorePm,scorePm,tMin,0);
end

% Gradient at the current Ps. It is refreshed right after each update, so
% the "simplex" score pairs Ps with its own gradient (previously the stale
% pre-update gradient was used) and the next update reuses it.
if (num_updates >= 1)
  Gs = compute_gradient(Ps,A,B);
end

for iter=1:num_updates
  [Ps,tMatch] = frank_wolfe_update(Ps,Gs,Pm,A,B);
  Gs = compute_gradient(Ps,A,B);
  Pm = permutation_match(Ps);
  scorePm = vertex_score(Pm);
  scorePs = round(full(0.5*sum(Gs.*Ps,'all')));
  tMin = toc(tStart)/secs_per_minute;
  fprintf(1,row_fmt,iter,scorePm,scorePs,tMin,tMatch);
end

end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

function [P,tMatch] = frank_wolfe_update(P,Gp,Pm,A,B)
% One Frank-Wolfe ascent step from P toward the permutation matrix Q
% returned by permutation_match(Gp,Pm).
%
% Inputs:
%   P     - current doubly stochastic matrix
%   Gp    - gradient at P (the caller passes compute_gradient(P,A,B))
%   Pm    - second argument to permutation_match (presumably a warm start
%           for the matching -- TODO confirm)
%   A, B  - matrices passed to compute_gradient
% Outputs:
%   P      - P + step*(Q-P) with step in [0,1]; unchanged if Q equals P
%   tMatch - seconds spent in permutation_match
%
% Line search: treating the gradient as varying linearly along the segment,
% the gain at P + t*(Q-P) is
%   phi(t) = numer*t + 0.5*denom*t^2,
% with numer = <Gp,Q-P> and denom = <Gq-Gp,Q-P>. The step maximizes phi on
% [0,1] explicitly, so it never moves downhill and never produces NaN, even
% if numer <= 0 (e.g. an inexact matching) or numer == denom == 0.

% PROJECT GRADIENT TO PERMUTATION MATRIX
tStart = tic;
Q = permutation_match(Gp,Pm);
tMatch = toc(tStart);

% DONE?
if (isequal(P,Q))
  return;
end

% COMPUTE LINE-SEARCH COEFFICIENTS
D = Q - P;
Gq = compute_gradient(Q,A,B);
numer = full(sum(Gp.*D,'all'));
denom = full(sum((Gq-Gp).*D,'all'));

% MAXIMIZE phi(t) OVER t IN [0,1]
if (denom < 0)
  % CONCAVE: MAXIMUM AT THE VERTEX -numer/denom, CLAMPED INTO [0,1]
  step = min(max(-numer/denom,0),1);
elseif (numer + 0.5*denom > 0)
  % CONVEX OR LINEAR: MAXIMUM AT AN ENDPOINT, AND phi(1) > phi(0) = 0
  step = 1;
else
  % NO GAIN ANYWHERE ON THE SEGMENT: STAY PUT
  step = 0;
end

% INTERPOLATE
P = P + step*D;
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%